Transformers DataCollatorForSeq2Seq

DataCollatorForSeq2Seq 是一个特殊的数据整理工具,用于序列到序列Seq2Seq)任务,如机器翻译文本摘要等。它将输入和目标序列进行正确的填充和处理,以便它们可以被用于训练 Transformer 模型。

导入库和模块

from transformers import DataCollatorForSeq2Seq, BertTokenizer

创建 tokenizer 和 DataCollator

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
data_collator = DataCollatorForSeq2Seq(tokenizer, model_type="bert")

我们首先加载了一个预训练的 BERT tokenizer,然后创建了一个 DataCollatorForSeq2Seq 实例,它将用于处理我们的数据。

准备数据

# This is a toy example. In practice, you would load your data from a file, preprocess it, etc.
examples = [{
    "input_ids": tokenizer.encode("Hello, world!", 
							      return_tensors="pt"),
    "labels": tokenizer.encode("Hello, world!", 
							   return_tensors="pt")
}]

这里我们创建了一个包含单个样本的数据集。每个样本都包含 "input_ids" 和 "labels" 字段,分别表示输入序列和目标序列。

使用 DataCollator

batch = data_collator(examples)

DataCollatorForSeq2Seq 的主要功能是将样本组合成一个批次,以便可以一次将多个样本传递给模型。在这个例子中,我们的批次只包含一个样本,但在实际使用中,批次通常会包含多个样本。

注意:DataCollatorForSeq2Seq 在处理数据时,会自动进行适当的填充,以确保所有的序列都有相同的长度。这是因为 Transformer 模型需要输入的所有序列都有相同的长度。然而,DataCollatorForSeq2Seq 不会对 "labels" 字段进行填充,因为在计算损失时,我们通常不希望考虑填充的部分。


本文作者:Maeiee

本文链接:Transformers DataCollatorForSeq2Seq

版权声明:如无特别声明,本文即为原创文章,版权归 Maeiee 所有,未经允许不得转载!


喜欢我文章的朋友请随缘打赏,鼓励我创作更多更好的作品!